#!/usr/bin/env python3

import math
from typing import Type

import torch
import torch.nn as nn


class SelfAttention(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        causal: bool,
        num_tokens: int | None = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        # assertion
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        assert not (
            causal or num_tokens is not None
        ), "Plese 'set num_tokens' to use causal masking attention"

        # args
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)
        self.causal = causal
        ## causal args
        self.num_tokens = num_tokens

        # layer
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        ## causal
        if self.causal:
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(self.num_tokens, self.num_tokens)).view(
                    1, 1, self.num_tokens, self.num_tokens
                ),
            )

    def forward_vanilla(self, x: torch.Tensor) -> torch.Tensor:
        B, n, d = x.size()

        # qkv-projection
        q, k, v = self.qkv(x).split(d, dim=-1)
        q = q.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        # normalize
        q, k = self.q_norm(q), self.k_norm(k)

        # attention-logit
        q = q * self.scale
        attn = torch.einsum("Bhnd,Bhod->Bhno", q, k)

        # softmax
        attn = torch.softmax(attn, -1)
        attn = self.attn_drop(attn)

        # qkv
        x = torch.einsum("Bhno,Bhod->Bhnd", attn, v)

        # projection
        x = x.transpose(1, 2).contiguous().view(B, n, d)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def forward_causal(self, x: torch.Tensor) -> torch.Tensor:
        B, n, d = x.size()

        # qkv-projection
        q, k, v = self.qkv(x).split(d, dim=-1)
        q = q.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        # normalize
        q, k = self.q_norm(q), self.k_norm(k)

        # calc attention-logit
        q = q * self.scale
        attn = torch.einsum("Bhnd,Bhod->Bhno", q, k)

        # attention masking
        attn = attn.masked_fill(self.bias[:, :, :n, :n] == 0, float("-inf"))

        # softmax
        attn = attn.softmax(-1)
        attn = self.attn_drop(attn)

        # qkv
        x = torch.einsum("Bhno,Bhod->Bhnd", attn, v)

        # projection
        x = x.transpose(1, 2).contiguous().view(B, n, d)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        if self.causal:
            x = self.forward_causal(x)
        else:
            x = self.forward_vanilla(x)

        return x
